#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Doris Configuration Management Module
Implements configuration loading, validation and management functionality
"""

import json
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

try:
    from dotenv import load_dotenv
except ImportError:
    load_dotenv = None


@dataclass
class DatabaseConfig:
    """Database connection configuration.

    Groups the connection endpoint and credentials, the FE/BE HTTP ports and
    the connection-pool settings. Every field has a default, so
    ``DatabaseConfig()`` can be created without arguments.
    """

    # Connection endpoint and credentials. host, port, user and database are
    # the parts rendered by DorisConfig.get_connection_string().
    host: str = "localhost"
    port: int = 9030
    user: str = "root"
    password: str = ""
    database: str = "information_schema"
    charset: str = "UTF8"

    # FE HTTP API port for profile and other HTTP APIs
    fe_http_port: int = 8030

    # BE nodes configuration for external access
    # If be_hosts is empty, will use "show backends" to get BE nodes
    be_hosts: list[str] = field(default_factory=list)
    be_webserver_port: int = 8040

    # Connection pool configuration
    # DorisConfig.validate() requires min_connections > 0 and
    # max_connections strictly greater than min_connections.
    # NOTE(review): the timeout/interval/age values look like seconds --
    # confirm against the connection-pool code that consumes them.
    min_connections: int = 5
    max_connections: int = 20
    connection_timeout: int = 30
    health_check_interval: int = 60
    max_connection_age: int = 3600


@dataclass
class SecurityConfig:
    """Security configuration.

    Covers authentication settings, SQL restrictions (blocked keywords and
    size/complexity limits), sensitive-table metadata and data masking.
    """

    # Authentication configuration
    # DorisConfig.validate() accepts only "token", "basic" or "oauth".
    auth_type: str = "token"  # token, basic, oauth
    # NOTE(review): hard-coded placeholder secret. DorisConfig.from_env()
    # overrides it from TOKEN_SECRET, but DorisConfig.validate() does not
    # reject the placeholder -- confirm deployments always replace it.
    token_secret: str = "default_secret"
    # NOTE(review): presumably seconds -- confirm with the token code.
    token_expiry: int = 3600

    # SQL security configuration
    # Each instance gets its own list (default_factory), so mutating one
    # instance's blocked_keywords does not affect another.
    blocked_keywords: list[str] = field(
        default_factory=lambda: [
            "DROP",
            "DELETE",
            "TRUNCATE",
            "ALTER",
            "CREATE",
            "INSERT",
            "UPDATE",
            "GRANT",
            "REVOKE",
        ]
    )
    # Both limits must be greater than 0 to pass DorisConfig.validate().
    max_query_complexity: int = 100
    max_result_rows: int = 10000

    # Sensitive table configuration
    # NOTE(review): the meaning of the mapping's keys/values is not visible
    # in this module -- verify against the code that reads it.
    sensitive_tables: dict[str, str] = field(default_factory=dict)

    # Data masking configuration
    enable_masking: bool = True
    masking_rules: list[dict[str, Any]] = field(default_factory=list)


@dataclass
class PerformanceConfig:
    """Performance configuration.

    Holds query-cache, concurrency, connection-pool and response-size
    settings. Every field has a default.
    """

    # Query cache configuration
    # cache_ttl must be greater than 0 to pass DorisConfig.validate().
    # NOTE(review): cache_ttl is presumably in seconds and max_cache_size a
    # number of entries -- confirm with the cache implementation.
    enable_query_cache: bool = True
    cache_ttl: int = 300
    max_cache_size: int = 1000

    # Concurrency control configuration
    # Both values must be greater than 0 to pass DorisConfig.validate().
    max_concurrent_queries: int = 50
    query_timeout: int = 300

    # Connection pool optimization configuration
    connection_pool_size: int = 20
    idle_timeout: int = 1800

    # Response content size limit (characters)
    max_response_content_size: int = 4096


@dataclass
class LoggingConfig:
    """Logging configuration.

    Consumed by ConfigManager.setup_logging(), which builds a console
    handler plus optional rotating file and audit handlers from these values.
    """

    # Level name; setup_logging() upper-cases it and resolves it as an
    # attribute of the ``logging`` module.
    level: str = "INFO"
    # Format string passed to logging.Formatter.
    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    # Optional log file; no file handler is created when this is None/empty.
    file_path: str | None = None
    # Rotation settings, passed as maxBytes / backupCount to
    # RotatingFileHandler for both the main log file and the audit log.
    max_file_size: int = 10 * 1024 * 1024  # 10MB
    backup_count: int = 5

    # Audit log configuration
    # An audit handler is attached only when enable_audit is true AND
    # audit_file_path is set.
    enable_audit: bool = True
    audit_file_path: str | None = None


@dataclass
class MonitoringConfig:
    """Monitoring configuration.

    Holds metrics, health-check and alerting settings. Every field has a
    default.
    """

    # Metrics collection configuration
    # metrics_port must be within 1-65535 to pass DorisConfig.validate().
    enable_metrics: bool = True
    metrics_port: int = 3001
    metrics_path: str = "/metrics"

    # Health check configuration
    # health_check_port must be within 1-65535 to pass DorisConfig.validate().
    health_check_port: int = 3002
    health_check_path: str = "/health"

    # Alert configuration
    # Alerts are disabled by default and no webhook URL is configured.
    enable_alerts: bool = False
    alert_webhook_url: str | None = None


@dataclass
class DorisConfig:
    """Doris MCP Server complete configuration.

    Aggregates the server-level settings with one sub-configuration object
    per area (database, security, performance, logging, monitoring) and a
    free-form ``custom_config`` dictionary. Instances can be built from the
    defaults, from a JSON file (``from_file``) or from environment variables
    (``from_env``).
    """

    # Basic configuration
    server_name: str = "doris-mcp-server"
    server_version: str = "0.4.0"
    server_port: int = 3000
    transport: str = "stdio"

    # Temporary files configuration
    temp_files_dir: str = "tmp"  # Temporary files directory for Explain and Profile outputs

    # Sub-configuration modules
    # default_factory gives every DorisConfig its own sub-config instances.
    database: DatabaseConfig = field(default_factory=DatabaseConfig)
    security: SecurityConfig = field(default_factory=SecurityConfig)
    performance: PerformanceConfig = field(default_factory=PerformanceConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)

    # Custom configuration
    # Free-form values; serialized under the "custom" key by to_dict().
    custom_config: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_file(cls, config_path: str) -> "DorisConfig":
        """Load configuration from a JSON file.

        Args:
            config_path: Path of the configuration file. Only files with a
                ``.json`` suffix (case-insensitive) are supported.

        Returns:
            A configuration built from the defaults, overridden by the
            values found in the file.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the format is unsupported, the file cannot be
                read or parsed, or its content cannot be applied. The
                original exception is attached as ``__cause__``.
        """
        config_file = Path(config_path)

        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file does not exist: {config_path}")

        try:
            # Reject unsupported formats before opening the file at all.
            if config_file.suffix.lower() != ".json":
                raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")

            with open(config_file, encoding="utf-8") as f:
                config_data = json.load(f)

            # Valid JSON may still be a list/number/string; fail with a
            # message that says so instead of an opaque attribute error.
            if not isinstance(config_data, dict):
                raise ValueError(
                    f"top-level JSON value must be an object, got {type(config_data).__name__}"
                )

            return cls._from_dict(config_data)

        except Exception as e:
            # Boundary handler: callers rely on ValueError for every load
            # failure, so the type is kept and the cause is chained.
            raise ValueError(f"Failed to load configuration file: {e}") from e

    @classmethod
    def from_env(cls, env_file: str | None = None) -> "DorisConfig":
        """Load configuration from environment variables.

        Starts from the dataclass defaults and overrides each setting whose
        environment variable is set. When python-dotenv is installed, a .env
        file is loaded into the environment first.

        Parsing rules:
            * Integer settings: an unset or blank variable keeps the default;
              anything else must parse with ``int()``.
            * Boolean settings: an unset variable keeps the default; a set
              variable is true when it equals ``true``, ``1``, ``yes`` or
              ``on`` (case-insensitive, surrounding whitespace ignored) and
              false otherwise.
            * String settings are taken verbatim.

        Args:
            env_file: .env file path. If None, the first existing file among
                .env, .env.local, .env.production, .env.development
                (resolved against the current working directory) is loaded.

        Returns:
            The populated configuration.

        Raises:
            ValueError: If an integer setting holds a non-integer value; the
                message names the offending variable.
        """
        log = logging.getLogger(__name__)

        def env_int(name: str, default: int) -> int:
            """Read an integer variable, falling back to *default* when unset/blank."""
            raw = os.getenv(name)
            if raw is None or not raw.strip():
                return default
            try:
                return int(raw)
            except ValueError as exc:
                raise ValueError(
                    f"Environment variable {name} must be an integer, got {raw!r}"
                ) from exc

        def env_bool(name: str, default: bool) -> bool:
            """Read a boolean variable, falling back to *default* when unset."""
            raw = os.getenv(name)
            if raw is None:
                return default
            return raw.strip().lower() in {"true", "1", "yes", "on"}

        # Load .env file
        if load_dotenv is None:
            log.warning("python-dotenv not installed, cannot load .env files")
        elif env_file:
            # Load specified .env file
            if Path(env_file).exists():
                load_dotenv(env_file)
                log.info("Loaded environment configuration file: %s", env_file)
            else:
                log.warning("Environment configuration file does not exist: %s", env_file)
        else:
            # Only the first existing candidate is loaded, in priority order.
            candidates = (".env", ".env.local", ".env.production", ".env.development")
            found = next((path for path in candidates if Path(path).exists()), None)
            if found is None:
                log.info("No .env configuration file found, using system environment variables")
            else:
                load_dotenv(found)
                log.info("Loaded environment configuration file: %s", found)

        config = cls()

        # Database configuration
        db = config.database
        db.host = os.getenv("DORIS_HOST", db.host)
        db.port = env_int("DORIS_PORT", db.port)
        db.user = os.getenv("DORIS_USER", db.user)
        db.password = os.getenv("DORIS_PASSWORD", db.password)
        db.database = os.getenv("DORIS_DATABASE", db.database)
        db.fe_http_port = env_int("DORIS_FE_HTTP_PORT", db.fe_http_port)

        # BE nodes configuration: comma-separated list, blank entries dropped
        be_hosts_env = os.getenv("DORIS_BE_HOSTS", "")
        if be_hosts_env:
            db.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
        db.be_webserver_port = env_int("DORIS_BE_WEBSERVER_PORT", db.be_webserver_port)

        # Connection pool configuration
        db.min_connections = env_int("DORIS_MIN_CONNECTIONS", db.min_connections)
        db.max_connections = env_int("DORIS_MAX_CONNECTIONS", db.max_connections)
        db.connection_timeout = env_int("DORIS_CONNECTION_TIMEOUT", db.connection_timeout)
        db.health_check_interval = env_int("DORIS_HEALTH_CHECK_INTERVAL", db.health_check_interval)
        db.max_connection_age = env_int("DORIS_MAX_CONNECTION_AGE", db.max_connection_age)

        # Security configuration
        sec = config.security
        sec.auth_type = os.getenv("AUTH_TYPE", sec.auth_type)
        sec.token_secret = os.getenv("TOKEN_SECRET", sec.token_secret)
        sec.token_expiry = env_int("TOKEN_EXPIRY", sec.token_expiry)
        sec.max_result_rows = env_int("MAX_RESULT_ROWS", sec.max_result_rows)
        sec.max_query_complexity = env_int("MAX_QUERY_COMPLEXITY", sec.max_query_complexity)
        sec.enable_masking = env_bool("ENABLE_MASKING", sec.enable_masking)

        # Performance configuration
        perf = config.performance
        perf.enable_query_cache = env_bool("ENABLE_QUERY_CACHE", perf.enable_query_cache)
        perf.cache_ttl = env_int("CACHE_TTL", perf.cache_ttl)
        perf.max_cache_size = env_int("MAX_CACHE_SIZE", perf.max_cache_size)
        perf.max_concurrent_queries = env_int("MAX_CONCURRENT_QUERIES", perf.max_concurrent_queries)
        perf.query_timeout = env_int("QUERY_TIMEOUT", perf.query_timeout)
        perf.max_response_content_size = env_int(
            "MAX_RESPONSE_CONTENT_SIZE", perf.max_response_content_size
        )

        # Logging configuration
        log_cfg = config.logging
        log_cfg.level = os.getenv("LOG_LEVEL", log_cfg.level)
        log_cfg.file_path = os.getenv("LOG_FILE_PATH", log_cfg.file_path)
        log_cfg.enable_audit = env_bool("ENABLE_AUDIT", log_cfg.enable_audit)
        log_cfg.audit_file_path = os.getenv("AUDIT_FILE_PATH", log_cfg.audit_file_path)

        # Monitoring configuration
        mon = config.monitoring
        mon.enable_metrics = env_bool("ENABLE_METRICS", mon.enable_metrics)
        mon.metrics_port = env_int("METRICS_PORT", mon.metrics_port)
        mon.health_check_port = env_int("HEALTH_CHECK_PORT", mon.health_check_port)
        mon.enable_alerts = env_bool("ENABLE_ALERTS", mon.enable_alerts)
        mon.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", mon.alert_webhook_url)

        # Server configuration
        config.server_name = os.getenv("SERVER_NAME", config.server_name)
        config.server_version = os.getenv("SERVER_VERSION", config.server_version)
        config.server_port = env_int("SERVER_PORT", config.server_port)
        config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir)

        return config

    @classmethod
    def _from_dict(cls, config_data: dict[str, Any]) -> "DorisConfig":
        """Create a configuration object from a dictionary.

        Starts from the defaults and applies the recognised keys of
        ``config_data``: the top-level server settings, the five section
        dictionaries (``database``, ``security``, ``performance``,
        ``logging``, ``monitoring``) and the free-form ``custom`` entry.
        Unknown keys are ignored.

        Args:
            config_data: Parsed configuration mapping.

        Returns:
            The populated configuration.

        Raises:
            ValueError: If a section is present but is not a dictionary.
        """
        # Local import: the module only imports `dataclass` and `field`.
        from dataclasses import fields

        config = cls()

        # Update basic configuration
        for key in ("server_name", "server_version", "server_port", "transport", "temp_files_dir"):
            if key in config_data:
                setattr(config, key, config_data[key])

        # Update each sub-configuration section
        for section_name in ("database", "security", "performance", "logging", "monitoring"):
            if section_name not in config_data:
                continue

            overrides = config_data[section_name]
            if not isinstance(overrides, dict):
                raise ValueError(
                    f"Configuration section '{section_name}' must be an object, "
                    f"got {type(overrides).__name__}"
                )

            section = getattr(config, section_name)
            # Only declared dataclass fields may be set. hasattr() would also
            # accept methods and dunders such as __dict__ or __class__.
            known_fields = {f.name for f in fields(section)}
            for key, value in overrides.items():
                if key in known_fields:
                    setattr(section, key, value)

        # Custom configuration
        config.custom_config = config_data.get("custom", {})

        return config

    def to_dict(self) -> dict[str, Any]:
        """Convert the configuration to a display-safe dictionary.

        The result is a masked view, not a lossless serialization: the
        database password and the token secret are replaced by ``"***"`` and
        ``masking_rules`` is reported as the number of rules. List and dict
        values (``be_hosts``, ``blocked_keywords``, ``sensitive_tables``,
        ``custom``) are the same objects held by the configuration, not
        copies.

        Returns:
            A dictionary with the server-level settings plus the
            ``database``, ``security``, ``performance``, ``logging``,
            ``monitoring`` and ``custom`` sections.
        """
        db = self.database
        sec = self.security
        perf = self.performance
        log_cfg = self.logging
        mon = self.monitoring

        database_section = {
            "host": db.host,
            "port": db.port,
            "user": db.user,
            "password": "***",  # Hide password
            "database": db.database,
            "charset": db.charset,
            "fe_http_port": db.fe_http_port,
            "be_hosts": db.be_hosts,
            "be_webserver_port": db.be_webserver_port,
            "min_connections": db.min_connections,
            "max_connections": db.max_connections,
            "connection_timeout": db.connection_timeout,
            "health_check_interval": db.health_check_interval,
            "max_connection_age": db.max_connection_age,
        }

        security_section = {
            "auth_type": sec.auth_type,
            "token_secret": "***",  # Hide secret key
            "token_expiry": sec.token_expiry,
            "blocked_keywords": sec.blocked_keywords,
            "max_query_complexity": sec.max_query_complexity,
            "max_result_rows": sec.max_result_rows,
            "sensitive_tables": sec.sensitive_tables,
            "enable_masking": sec.enable_masking,
            # Only the rule count is exposed, not the rules themselves.
            "masking_rules": len(sec.masking_rules),
        }

        performance_section = {
            "enable_query_cache": perf.enable_query_cache,
            "cache_ttl": perf.cache_ttl,
            "max_cache_size": perf.max_cache_size,
            "max_concurrent_queries": perf.max_concurrent_queries,
            "query_timeout": perf.query_timeout,
            "connection_pool_size": perf.connection_pool_size,
            "idle_timeout": perf.idle_timeout,
            "max_response_content_size": perf.max_response_content_size,
        }

        logging_section = {
            "level": log_cfg.level,
            "format": log_cfg.format,
            "file_path": log_cfg.file_path,
            "max_file_size": log_cfg.max_file_size,
            "backup_count": log_cfg.backup_count,
            "enable_audit": log_cfg.enable_audit,
            "audit_file_path": log_cfg.audit_file_path,
        }

        monitoring_section = {
            "enable_metrics": mon.enable_metrics,
            "metrics_port": mon.metrics_port,
            "metrics_path": mon.metrics_path,
            "health_check_port": mon.health_check_port,
            "health_check_path": mon.health_check_path,
            "enable_alerts": mon.enable_alerts,
            "alert_webhook_url": mon.alert_webhook_url,
        }

        return {
            "server_name": self.server_name,
            "server_version": self.server_version,
            "server_port": self.server_port,
            # transport is a declared server-level field and belongs in the view.
            "transport": self.transport,
            "temp_files_dir": self.temp_files_dir,
            "database": database_section,
            "security": security_section,
            "performance": performance_section,
            "logging": logging_section,
            "monitoring": monitoring_section,
            "custom": self.custom_config,
        }

    def save_to_file(self, config_path: str) -> None:
        """Save the configuration to a JSON file.

        The content written is ``to_dict()``, i.e. the masked view: the
        password and token secret are stored as ``"***"`` and
        ``masking_rules`` as a count, so the file is not a lossless copy of
        this configuration.

        The format check and the serialization both happen before the target
        is opened, so a failure never truncates or partially overwrites an
        existing file.

        Args:
            config_path: Destination path; must have a ``.json`` suffix
                (case-insensitive). Missing parent directories are created.

        Raises:
            ValueError: If the format is unsupported, the configuration
                cannot be serialized, or the file cannot be written. The
                original exception is attached as ``__cause__``.
            OSError: If the parent directory cannot be created.
        """
        config_file = Path(config_path)

        try:
            if config_file.suffix.lower() != ".json":
                raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")
            # Serialize up front: opening with "w" truncates immediately.
            payload = json.dumps(self.to_dict(), indent=2, ensure_ascii=False)
        except Exception as e:
            raise ValueError(f"Failed to save configuration file: {e}") from e

        config_file.parent.mkdir(parents=True, exist_ok=True)

        try:
            with open(config_file, "w", encoding="utf-8") as f:
                f.write(payload)
        except Exception as e:
            raise ValueError(f"Failed to save configuration file: {e}") from e

    def validate(self) -> list[str]:
        """Validate the configuration.

        Checks the database, security, performance, logging and monitoring
        sections and collects one message per violated rule; nothing is
        raised for an invalid value. The log level is compared
        case-insensitively, matching ``ConfigManager.setup_logging`` which
        upper-cases the level before resolving it.

        Returns:
            The list of validation error messages, empty when the
            configuration is valid.
        """
        errors: list[str] = []

        db = self.database
        sec = self.security
        perf = self.performance
        log_cfg = self.logging
        mon = self.monitoring

        # Validate database configuration
        if not db.host:
            errors.append("Database host address cannot be empty")

        if not (1 <= db.port <= 65535):
            errors.append("Database port must be in the range 1-65535")

        if not db.user:
            errors.append("Database username cannot be empty")

        if db.min_connections <= 0:
            errors.append("Minimum connections must be greater than 0")

        if db.max_connections <= db.min_connections:
            errors.append("Maximum connections must be greater than minimum connections")

        # Validate security configuration
        if sec.auth_type not in ("token", "basic", "oauth"):
            errors.append("Authentication type must be one of token, basic, or oauth")

        if sec.token_expiry <= 0:
            errors.append("Token expiry time must be greater than 0")

        if sec.max_query_complexity <= 0:
            errors.append("Maximum query complexity must be greater than 0")

        if sec.max_result_rows <= 0:
            errors.append("Maximum result rows must be greater than 0")

        # Validate performance configuration
        if perf.cache_ttl <= 0:
            errors.append("Cache TTL must be greater than 0")

        if perf.max_concurrent_queries <= 0:
            errors.append("Maximum concurrent queries must be greater than 0")

        if perf.query_timeout <= 0:
            errors.append("Query timeout must be greater than 0")

        # Validate logging configuration
        # str() keeps a non-string level (e.g. None) a validation error
        # rather than an AttributeError.
        valid_levels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
        if str(log_cfg.level).upper() not in valid_levels:
            errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")

        if log_cfg.max_file_size <= 0:
            errors.append("Maximum log file size must be greater than 0")

        if log_cfg.backup_count < 0:
            errors.append("Log backup count cannot be negative")

        # Validate monitoring configuration
        if not (1 <= mon.metrics_port <= 65535):
            errors.append("Monitoring port must be in the range 1-65535")

        if not (1 <= mon.health_check_port <= 65535):
            errors.append("Health check port must be in the range 1-65535")

        return errors

    def get_connection_string(self) -> str:
        """Return a ``mysql://`` URL for the database with the password masked as ``***``."""
        db = self.database
        credentials = f"{db.user}:***"
        endpoint = f"{db.host}:{db.port}"
        return f"mysql://{credentials}@{endpoint}/{db.database}"

    def get_config_summary(self) -> dict[str, Any]:
        """Build a compact overview of the configuration.

        Returns:
            A dictionary with three display strings (``server``,
            ``database``, ``connection_pool``) and three nested dictionaries
            (``security``, ``performance``, ``monitoring``) holding a few key
            settings each. No credentials are included.
        """
        db = self.database
        sec = self.security
        perf = self.performance
        mon = self.monitoring

        security_summary = {
            "auth_type": sec.auth_type,
            "masking_enabled": sec.enable_masking,
            "blocked_keywords_count": len(sec.blocked_keywords),
        }
        performance_summary = {
            "cache_enabled": perf.enable_query_cache,
            "max_concurrent": perf.max_concurrent_queries,
            "query_timeout": perf.query_timeout,
        }
        monitoring_summary = {
            "metrics_enabled": mon.enable_metrics,
            "alerts_enabled": mon.enable_alerts,
        }

        return {
            "server": f"{self.server_name} v{self.server_version}",
            "database": f"{db.host}:{db.port}/{db.database}",
            "connection_pool": f"{db.min_connections}-{db.max_connections}",
            "security": security_summary,
            "performance": performance_summary,
            "monitoring": monitoring_summary,
        }


class ConfigManager:
    """Configuration manager class.

    Wraps a ``DorisConfig`` and applies it: configures the ``logging``
    module, runs validation and logs a configuration summary.
    """

    def __init__(self, config: "DorisConfig"):
        """Store the configuration and create this module's logger.

        Args:
            config: The configuration to manage.
        """
        self.config = config
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _detach_handlers(target: logging.Logger) -> None:
        """Remove every handler from *target* and close it to release its file/stream."""
        for handler in target.handlers[:]:
            target.removeHandler(handler)
            handler.close()

    def setup_logging(self):
        """Set up logging from ``config.logging``.

        Replaces the root logger's handlers with a console handler and, when
        ``file_path`` is set, a rotating file handler. When auditing is
        enabled and ``audit_file_path`` is set, the ``audit`` logger gets a
        single rotating file handler at INFO level.

        Safe to call repeatedly: handlers from a previous call are closed
        and replaced instead of accumulating. Parent directories of the log
        files are created when missing. A failure to set up a file or audit
        handler is logged as a warning and does not abort the setup.

        Raises:
            AttributeError: If the configured level is not the name of an
                attribute of the ``logging`` module.
        """
        from logging.handlers import RotatingFileHandler

        log_cfg = self.config.logging

        # Configure root logger
        root_logger = logging.getLogger()
        root_logger.setLevel(getattr(logging, log_cfg.level.upper()))

        # Drop existing handlers and close them, so that file handlers from
        # an earlier configuration do not leak open files.
        self._detach_handlers(root_logger)

        formatter = logging.Formatter(log_cfg.format)

        # Console handler
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)

        # File handler (if configured)
        if log_cfg.file_path:
            try:
                Path(log_cfg.file_path).parent.mkdir(parents=True, exist_ok=True)
                file_handler = RotatingFileHandler(
                    log_cfg.file_path,
                    maxBytes=log_cfg.max_file_size,
                    backupCount=log_cfg.backup_count,
                    encoding="utf-8",
                )
                file_handler.setFormatter(formatter)
                root_logger.addHandler(file_handler)
            except Exception as e:
                # Best effort: console logging keeps working without the file.
                self.logger.warning("Failed to setup file logging: %s", e)

        # Audit log handler (if configured)
        if log_cfg.enable_audit and log_cfg.audit_file_path:
            try:
                Path(log_cfg.audit_file_path).parent.mkdir(parents=True, exist_ok=True)
                audit_handler = RotatingFileHandler(
                    log_cfg.audit_file_path,
                    maxBytes=log_cfg.max_file_size,
                    backupCount=log_cfg.backup_count,
                    encoding="utf-8",
                )
                audit_handler.setFormatter(formatter)

                audit_logger = logging.getLogger("audit")
                # Swap only once the new handler exists: a failed setup keeps
                # the previous audit handler, and a repeated setup does not
                # write every audit record twice.
                self._detach_handlers(audit_logger)
                audit_logger.addHandler(audit_handler)
                audit_logger.setLevel(logging.INFO)
            except Exception as e:
                self.logger.warning("Failed to setup audit logging: %s", e)

    def validate_config(self) -> bool:
        """Validate the configuration and log the outcome.

        Returns:
            True when ``config.validate()`` reports no errors; otherwise
            False, after logging each error.
        """
        errors = self.config.validate()
        if errors:
            self.logger.error("Configuration validation failed:")
            for error in errors:
                self.logger.error("  - %s", error)
            return False

        self.logger.info("Configuration validation passed")
        return True

    def log_config_summary(self):
        """Log the entries of ``config.get_config_summary()`` at INFO level."""
        summary = self.config.get_config_summary()
        self.logger.info("Configuration Summary:")
        self.logger.info("  Server: %s", summary["server"])
        self.logger.info("  Database: %s", summary["database"])
        self.logger.info("  Connection Pool: %s", summary["connection_pool"])
        self.logger.info("  Security: %s", summary["security"])
        self.logger.info("  Performance: %s", summary["performance"])
        self.logger.info("  Monitoring: %s", summary["monitoring"])


def create_default_config_file(config_path: str):
    """Write a configuration file holding the built-in defaults.

    Builds a default ``DorisConfig``, saves it to ``config_path`` through
    ``save_to_file`` and prints a confirmation line.
    """
    defaults = DorisConfig()
    defaults.save_to_file(config_path)
    print(f"Default configuration file created: {config_path}")


# Example usage
if __name__ == "__main__":
    # Start from the built-in defaults. Alternative sources:
    #   config = DorisConfig.from_env()                 # environment variables
    #   config = DorisConfig.from_file("config.json")   # JSON file
    config = DorisConfig()

    # Validate first; logging setup and saving only happen for a valid config.
    config_manager = ConfigManager(config)
    if not config_manager.validate_config():
        print("Configuration validation failed")
    else:
        config_manager.setup_logging()
        config_manager.log_config_summary()

        # Save configuration
        config.save_to_file("example_config.json")
        print("Configuration saved to example_config.json")
